/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "c/ddk/graph/model.h"
#include "c/ddk/graph/context.h"

#include <memory>
#include <string>

#include "resource_manager.h"
#include "framework/infra/log/log.h"
#include "graph/model.h"
#include "c/ddk/model_manager/hiai_model_builder.h"
#include "graph/buffer.h"
#include "base/common/file_util/file_util.h"

/**
 * Creates an empty IR model and registers it with the resource manager.
 *
 * The resource manager keeps the owning smart pointer; the returned handle is the
 * raw address of that object and is used as the lookup key by the other HIAI_IR_* calls.
 *
 * @param resMgr resource manager that will own the new model (must not be null)
 * @param name   model name, copied into the model (must not be null)
 * @return handle of the new model, or nullptr if an argument is null
 */
ModelHandle HIAI_IR_CreateModel(ResMgrHandle resMgr, const char* name)
{
    const bool argsValid = (resMgr != nullptr) && (name != nullptr);
    if (!argsValid) {
        FMK_LOGE("resMgr or name is invalid");
        return nullptr;
    }

    auto* manager = reinterpret_cast<ResourceManager*>(resMgr);
    const std::string modelName(name);
    BasePtr created = std::make_shared<ge::Model>(modelName, "version");
    // Ownership is handed to the manager; only the raw address escapes to the C caller.
    manager->StoreSrcPtr(created);
    return created.get();
}

/**
 * Serializes a registered IR model and builds it into a HIAI_MR_BuiltModel.
 *
 * @param resMgr     resource manager that owns the model object (must not be null)
 * @param model      model handle registered in resMgr (must not be null)
 * @param options    build options, forwarded unchanged to HIAI_MR_ModelBuilder_Build
 * @param builtModel output location, filled in by HIAI_MR_ModelBuilder_Build
 * @return HIAI_SUCCESS on success; HIAI_FAILURE on invalid arguments, an unknown or
 *         non-model handle, serialization failure, or build failure
 */
HIAI_Status HIAI_IR_CreateBuiltModel(ResMgrHandle resMgr, ConstModelHandle model,
    const HIAI_MR_ModelBuildOptions* options, HIAI_MR_BuiltModel** builtModel)
{
    if (resMgr == nullptr || model == nullptr) {
        FMK_LOGE("resMgr or model is nullptr");
        return HIAI_FAILURE;
    }
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr modelBasePtr = resMgrPtr->GetSrcPtr(model);
    if (modelBasePtr == nullptr) {
        FMK_LOGE("get model base ptr failed");
        return HIAI_FAILURE;
    }
    // The same manager also stores other object kinds (e.g. graphs), so a wrong handle
    // makes the downcast yield nullptr; that must not be dereferenced.
    std::shared_ptr<ge::Model> modelPtr = std::dynamic_pointer_cast<ge::Model>(modelBasePtr);
    if (modelPtr == nullptr) {
        FMK_LOGE("model handle does not refer to a model");
        return HIAI_FAILURE;
    }

    ge::Buffer buffer;
    if (modelPtr->Save(buffer) != ge::GRAPH_SUCCESS) {
        FMK_LOGE("save model failed.");
        return HIAI_FAILURE;
    }
    // The builder API takes a non-const void*; presumably it only reads the buffer
    // (NOTE(review): confirm). `buffer` stays alive until Build returns.
    void* data = static_cast<void*>(const_cast<uint8_t*>(buffer.GetData()));
    size_t size = buffer.GetSize();
    if (HIAI_MR_ModelBuilder_Build(options, modelPtr->GetName().c_str(), data, size, builtModel) != HIAI_SUCCESS) {
        FMK_LOGE("build model failed.");
        return HIAI_FAILURE;
    }

    return HIAI_SUCCESS;
}

/**
 * Dumps a registered IR model to a file via ge::Model::Dump.
 *
 * @param resMgr resource manager that owns the model object (must not be null)
 * @param model  model handle registered in resMgr (must not be null)
 * @param file   destination path passed to ge::Model::Dump (must not be null)
 * @return HIAI_SUCCESS on success; HIAI_FAILURE on invalid arguments, an unknown or
 *         non-model handle, or dump failure
 */
HIAI_Status HIAI_IR_DumpModel(ResMgrHandle resMgr, ConstModelHandle model, const char* file)
{
    if (resMgr == nullptr) {
        FMK_LOGE("resMgr is nullptr");
        return HIAI_FAILURE;
    }
    if (model == nullptr) {
        FMK_LOGE("model is nullptr");
        return HIAI_FAILURE;
    }
    if (file == nullptr) {
        FMK_LOGE("file is nullptr");
        return HIAI_FAILURE;
    }
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    // Use the raw model pointer passed in to look up the owning C++ model object in the resource manager.
    BasePtr modelBasePtr = resMgrPtr->GetSrcPtr(model);
    if (modelBasePtr == nullptr) {
        FMK_LOGE("get model base ptr failed");
        return HIAI_FAILURE;
    }
    // Downcast BasePtr to a model pointer; a non-model handle yields nullptr.
    std::shared_ptr<ge::Model> modelPtr = std::dynamic_pointer_cast<ge::Model>(modelBasePtr);
    if (modelPtr == nullptr) {
        FMK_LOGE("model handle does not refer to a model");
        return HIAI_FAILURE;
    }
    // Call the C++ dump implementation.
    if (modelPtr->Dump(file) != ge::GRAPH_SUCCESS) {
        FMK_LOGE("dump model failed.");
        return HIAI_FAILURE;
    }
    return HIAI_SUCCESS;
}

/**
 * Attaches a registered graph to a registered IR model and validates the result.
 *
 * @param resMgr resource manager that owns both objects (must not be null)
 * @param graph  graph handle registered in resMgr (must not be null)
 * @param model  model handle registered in resMgr (must not be null)
 * @return HIAI_SUCCESS if the graph was set and the model reports IsValid();
 *         HIAI_FAILURE on invalid arguments, unknown or wrongly-typed handles,
 *         or an invalid resulting model
 */
HIAI_Status HIAI_IR_SetGraph(ResMgrHandle resMgr, ConstGraphHandle graph, ModelHandle model)
{
    if (resMgr == nullptr || model == nullptr || graph == nullptr) {
        FMK_LOGE("resMgr, model or graph is nullptr");
        return HIAI_FAILURE;
    }
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr modelBasePtr = resMgrPtr->GetSrcPtr(model);
    if (modelBasePtr == nullptr) {
        FMK_LOGE("get model base ptr failed");
        return HIAI_FAILURE;
    }
    // Models and graphs share the same manager, so swapped handles make the
    // downcast fail; check before dereferencing.
    std::shared_ptr<ge::Model> modelPtr = std::dynamic_pointer_cast<ge::Model>(modelBasePtr);
    if (modelPtr == nullptr) {
        FMK_LOGE("model handle does not refer to a model");
        return HIAI_FAILURE;
    }

    BasePtr graphBasePtr = resMgrPtr->GetSrcPtr(graph);
    if (graphBasePtr == nullptr) {
        FMK_LOGE("get graph base ptr failed");
        return HIAI_FAILURE;
    }
    std::shared_ptr<ge::Graph> graphPtr = std::dynamic_pointer_cast<ge::Graph>(graphBasePtr);
    if (graphPtr == nullptr) {
        FMK_LOGE("graph handle does not refer to a graph");
        return HIAI_FAILURE;
    }
    modelPtr->SetGraph(*graphPtr);
    if (!modelPtr->IsValid()) {
        FMK_LOGE("ir_model is invalid.");
        return HIAI_FAILURE;
    }
    return HIAI_SUCCESS;
}

/**
 * Serializes a registered IR model with SaveFullModel and writes the bytes to a file.
 *
 * @param resMgr resource manager that owns the model object (must not be null)
 * @param model  model handle registered in resMgr (must not be null)
 * @param file   destination path for ge::FileUtil::WriteBufferToFile (must not be null)
 * @return HIAI_SUCCESS on success; HIAI_FAILURE on invalid arguments, an unknown or
 *         non-model handle, serialization failure, or file write failure
 */
HIAI_Status HIAI_IR_SaveModel(ResMgrHandle resMgr, ConstModelHandle model, const char* file)
{
    if (resMgr == nullptr) {
        FMK_LOGE("resMgr is nullptr");
        return HIAI_FAILURE;
    }
    if (model == nullptr) {
        FMK_LOGE("model is nullptr");
        return HIAI_FAILURE;
    }
    if (file == nullptr) {
        FMK_LOGE("file is nullptr");
        return HIAI_FAILURE;
    }
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr modelBasePtr = resMgrPtr->GetSrcPtr(model);
    if (modelBasePtr == nullptr) {
        FMK_LOGE("get model base ptr failed");
        return HIAI_FAILURE;
    }
    // A handle for a non-model object (e.g. a graph) yields nullptr here.
    std::shared_ptr<ge::Model> modelPtr = std::dynamic_pointer_cast<ge::Model>(modelBasePtr);
    if (modelPtr == nullptr) {
        FMK_LOGE("model handle does not refer to a model");
        return HIAI_FAILURE;
    }

    ge::Buffer modelbuffer;
    if (modelPtr->SaveFullModel(modelbuffer) != ge::GRAPH_SUCCESS) {
        FMK_LOGE("SaveFullModel failed.");
        return HIAI_FAILURE;
    }
    if (ge::FileUtil::WriteBufferToFile(
        static_cast<const void*>(modelbuffer.GetData()), modelbuffer.GetSize(), file) != ge::GRAPH_SUCCESS) {
        FMK_LOGE("ModelFileSaver Save file failed.");
        return HIAI_FAILURE;
    }

    return HIAI_SUCCESS;
}